{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Changing the base distribution of a flow model\n",
    "\n",
    "This example shows how one can easily change the base distribution with our API.\n",
    "First, let's look at how the normalizing flow can learn a two moons target distribution with a Gaussian distribution as the base."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q-wJe2vcAHdK"
   },
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib import cm\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setting up a flow model with a 2D Gaussian base distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0liFXnmGAdE5"
   },
   "outputs": [],
   "source": [
    "# Set up model\n",
    "\n",
    "# Base distribution: diagonal Gaussian in 2 dimensions\n",
    "base = nf.distributions.base.DiagGaussian(2)\n",
    "\n",
    "# Assemble the list of flows: every step contributes an affine coupling block\n",
    "# followed by a swap of the two dimensions\n",
    "num_layers = 32\n",
    "flows = []\n",
    "for layer_idx in range(num_layers):\n",
    "    flows += [\n",
    "        # Coupling parameters come from an MLP with two hidden layers of 64 units each;\n",
    "        # its last layer is initialized by zeros, making training more stable\n",
    "        nf.flows.AffineCouplingBlock(nf.nets.MLP([1, 64, 64, 2], init_zeros=True)),\n",
    "        nf.flows.Permute(2, mode='swap'),\n",
    "    ]\n",
    "\n",
    "# Combine base distribution and flow layers into the flow model\n",
    "model = nf.NormalizingFlow(base, flows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0rZb2ORAAfd_"
   },
   "outputs": [],
   "source": [
    "# Pick the compute device (GPU only if available and enabled) and move the model onto it\n",
    "enable_cuda = True\n",
    "if torch.cuda.is_available() and enable_cuda:\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r8QfcFpnAlNM"
   },
   "outputs": [],
   "source": [
    "# Two-moons density that the flow is trained to approximate\n",
    "target = nf.distributions.TwoMoons()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "9UkDaqTNAorD",
    "outputId": "d0a132bd-16f4-4816-96be-77111484b33f"
   },
   "outputs": [],
   "source": [
    "# Plot target distribution\n",
    "\n",
    "# Evaluation grid over [-3, 3] x [-3, 3]. indexing='ij' is passed explicitly because\n",
    "# calling torch.meshgrid without it is deprecated; 'ij' keeps the layout used so far.\n",
    "grid_size = 200\n",
    "xx, yy = torch.meshgrid(\n",
    "    torch.linspace(-3, 3, grid_size),\n",
    "    torch.linspace(-3, 3, grid_size),\n",
    "    indexing='ij',\n",
    ")\n",
    "# Flatten the grid into a batch of 2D points and move it to the compute device\n",
    "zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
    "zz = zz.to(device)\n",
    "\n",
    "# Target density on the grid; NaN entries are set to zero before plotting\n",
    "log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "IEdqxg1pAqNO",
    "outputId": "6f6ec280-3d3e-4385-8574-6bedc96619bb"
   },
   "outputs": [],
   "source": [
    "# Plot initial flow distribution\n",
    "\n",
    "# The grid density is only used for plotting, so it is evaluated in eval mode and\n",
    "# without gradient tracking; this avoids building an autograd graph for the whole grid\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "model.train()\n",
    "prob = torch.exp(log_prob)\n",
    "# Set NaN densities to zero before plotting\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "rsMgRWnkAsuc",
    "outputId": "c630d31a-c22d-40f3-ddbe-ff1037de7f97"
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 4000\n",
    "num_samples = 2 ** 9\n",
    "show_iter = 500\n",
    "\n",
    "# Losses are collected in a Python list and converted to an array once after the loop,\n",
    "# instead of reallocating a NumPy array with np.append on every iteration\n",
    "loss_hist = []\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)\n",
    "\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Get training samples\n",
    "    x = target.sample(num_samples).to(device)\n",
    "\n",
    "    # Compute loss\n",
    "    loss = model.forward_kld(x)\n",
    "\n",
    "    # Do backprop and optimizer step; skipped when the loss is NaN or infinite\n",
    "    if torch.isfinite(loss):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Log loss (non-finite values are recorded as well)\n",
    "    loss_hist.append(loss.item())\n",
    "\n",
    "    # Plot learned distribution every show_iter iterations\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        # Only needed for plotting, so evaluate without tracking gradients\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            log_prob = model.log_prob(zz)\n",
    "        model.train()\n",
    "        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n",
    "        prob[torch.isnan(prob)] = 0\n",
    "\n",
    "        plt.figure(figsize=(15, 15))\n",
    "        plt.pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')\n",
    "        plt.gca().set_aspect('equal', 'box')\n",
    "        plt.show()\n",
    "\n",
    "loss_hist = np.array(loss_hist)\n",
    "\n",
    "# Plot loss\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing the learned distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d7SthvmqAvb9"
   },
   "outputs": [],
   "source": [
    "# Show the target density and the density learned by the flow side by side\n",
    "f, ax = plt.subplots(1, 2, sharey=True, figsize=(15, 7))\n",
    "\n",
    "# Plot target distribution\n",
    "log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[0].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "\n",
    "ax[0].set_aspect('equal', 'box')\n",
    "ax[0].set_axis_off()\n",
    "ax[0].set_title('Target', fontsize=24)\n",
    "\n",
    "# Plot learned distribution; it is only drawn, so it is evaluated in eval mode\n",
    "# and without gradient tracking to avoid building an autograd graph for the grid\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "model.train()\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[1].pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')\n",
    "\n",
    "ax[1].set_aspect('equal', 'box')\n",
    "ax[1].set_axis_off()\n",
    "ax[1].set_title('Real NVP', fontsize=24)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FrNwgl4PDwZ2"
   },
   "source": [
    "Notice there is a bridge between the two modes of the learned distribution.\n",
    "This is usually not a big deal since the bridge is really thin, and going to a higher-dimensional space makes it exponentially unlikely to have samples within the bridge.\n",
    "However, we can see that the shape of each mode is also a bit distorted.\n",
    "So it would be nice to get rid of the bridge. Now let's try to use a Gaussian mixture distribution as our base distribution, instead of a single Gaussian."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use a Gaussian mixture model as the base instead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AUsgWGXeAxhN"
   },
   "outputs": [],
   "source": [
    "# Set up model\n",
    "\n",
    "# Define a mixture of Gaussians with 2 modes.\n",
    "base = nf.distributions.base.GaussianMixture(\n",
    "    2, 2, loc=[[-2, 0], [2, 0]], scale=[[0.3, 0.3], [0.3, 0.3]]\n",
    ")\n",
    "\n",
    "# Define list of flows\n",
    "num_layers = 32\n",
    "flows = []\n",
    "for i in range(num_layers):\n",
    "    # Neural network with two hidden layers having 64 units each\n",
    "    # Last layer is initialized by zeros making training more stable\n",
    "    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)\n",
    "    # Add flow layer\n",
    "    flows.append(nf.flows.AffineCouplingBlock(param_map))\n",
    "    # Swap dimensions\n",
    "    flows.append(nf.flows.Permute(2, mode='swap'))\n",
    "\n",
    "# Construct flow model and move it to the device selected earlier in the notebook.\n",
    "# A hard-coded .cuda() fails on machines without a GPU and ignores enable_cuda, which\n",
    "# would leave the model on a different device than the grid zz and the training samples.\n",
    "model = nf.NormalizingFlow(base, flows).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "ZwgP8TiKB-Ej",
    "outputId": "03438b3a-ac28-422a-b84e-1ce9cb9a7f46"
   },
   "outputs": [],
   "source": [
    "# Plot initial flow distribution\n",
    "\n",
    "# The grid density is only used for plotting, so it is evaluated in eval mode and\n",
    "# without gradient tracking; this avoids building an autograd graph for the whole grid\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "model.train()\n",
    "prob = torch.exp(log_prob)\n",
    "# Set NaN densities to zero before plotting\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "3TOELsELBhf1",
    "outputId": "b17c08d0-1b20-4db0-f830-310e9700587e"
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 4000\n",
    "num_samples = 2 ** 9\n",
    "show_iter = 500\n",
    "\n",
    "# Losses are collected in a Python list and converted to an array once after the loop,\n",
    "# instead of reallocating a NumPy array with np.append on every iteration\n",
    "loss_hist = []\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)\n",
    "\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Get training samples\n",
    "    x = target.sample(num_samples).to(device)\n",
    "\n",
    "    # Compute loss\n",
    "    loss = model.forward_kld(x)\n",
    "\n",
    "    # Do backprop and optimizer step; skipped when the loss is NaN or infinite\n",
    "    if torch.isfinite(loss):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Log loss (non-finite values are recorded as well)\n",
    "    loss_hist.append(loss.item())\n",
    "\n",
    "    # Plot learned distribution every show_iter iterations\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        # Only needed for plotting, so evaluate without tracking gradients\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            log_prob = model.log_prob(zz)\n",
    "        model.train()\n",
    "        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n",
    "        prob[torch.isnan(prob)] = 0\n",
    "\n",
    "        plt.figure(figsize=(15, 15))\n",
    "        plt.pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')\n",
    "        plt.gca().set_aspect('equal', 'box')\n",
    "        plt.show()\n",
    "\n",
    "loss_hist = np.array(loss_hist)\n",
    "\n",
    "# Plot loss\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3fxV_64sFAp1"
   },
   "source": [
    "Now the modes are in better shape! And there is no bridge between the two modes!"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
